"""
training functions
"""
import time
from math import inf

import torch
from torch.utils.data import DataLoader
from torch.nn import BCEWithLogitsLoss
import torch.nn.functional as F
from tqdm import tqdm
import wandb
import numpy as np

from utils import get_num_samples

from ext import pyext

def auc_loss(logits, y, num_neg=1):
    """Squared-margin AUC surrogate loss over positive/negative groups.

    Each positive logit is compared against ``num_neg`` negative logits and
    the loss is the sum of ``(1 - (pos - neg)) ** 2`` over all comparisons.

    Positives and negatives are grouped purely by their order in ``logits``:
    the i-th positive is matched with the i-th run of ``num_neg`` negatives.
    Surplus positives or negatives that cannot form a complete group are
    dropped.

    Args:
        logits: tensor of raw scores, indexable by the boolean mask ``y == 1``.
        y: tensor of labels; entries equal to 1 are positives and entries
            equal to 0 are negatives.
        num_neg: number of negatives compared against each positive.

    Returns:
        A 0-dim tensor holding the summed loss (0 when no complete group
        can be formed).

    Raises:
        ValueError: if ``num_neg`` is smaller than 1.
    """
    if num_neg < 1:
        raise ValueError(f'num_neg must be >= 1, got {num_neg}')
    pos_out = logits[y == 1].reshape(-1)
    neg_out = logits[y == 0].reshape(-1)
    # hack, should really pair negative and positives in the training set
    # Keep only as many groups as both sides can fill, so that the reshape
    # below is always valid, including for num_neg > 1.
    num_groups = min(pos_out.numel(), neg_out.numel() // num_neg)
    pos_out = pos_out[:num_groups].reshape(num_groups, 1)
    neg_out = neg_out[:num_groups * num_neg].reshape(num_groups, num_neg)
    # (num_groups, 1) broadcasts against (num_groups, num_neg).
    return torch.square(1 - (pos_out - neg_out)).sum()


def bce_loss(logits, y, num_neg=1):
    """Mean binary cross-entropy between raw logits and 0/1 labels.

    Args:
        logits: tensor of raw scores; flattened to 1-D before the loss.
        y: tensor of labels; cast to float before the loss.
        num_neg: unused here; present so the signature matches ``auc_loss``.

    Returns:
        A 0-dim tensor with the loss averaged over all elements.
    """
    criterion = BCEWithLogitsLoss()
    flat_logits = logits.view(-1)
    targets = y.to(torch.float)
    return criterion(flat_logits, targets)


def get_loss(loss_str):
    """Return the loss function selected by name.

    Args:
        loss_str: ``'bce'`` for ``bce_loss`` or ``'auc'`` for ``auc_loss``.

    Returns:
        The matching loss callable, taking ``(logits, y, num_neg=1)``.

    Raises:
        NotImplementedError: if ``loss_str`` is not a supported name.
    """
    if loss_str == 'bce':
        return bce_loss
    if loss_str == 'auc':
        return auc_loss
    # Name the rejected value so a mistyped config option is easy to spot.
    raise NotImplementedError(
        f"unsupported loss {loss_str!r}; expected 'bce' or 'auc'"
    )


def train(model, optimizer, train_loader, args, device, emb=None):
    """Run one training epoch over every link in ``train_loader.dataset``.

    Batches are drawn by shuffling link indices with a fresh ``DataLoader``
    built here; ``train_loader`` itself is only used for its ``dataset`` and
    for its length when logging. The loss is binary cross-entropy with
    logits, with negatives down-weighted by the positive/negative ratio.

    Args:
        model: module called as ``model(path_features, node_features,
            src_degrees, dst_degrees, RA, batch_emb)``. Must expose
            ``node_embedding`` (``None`` or an object with ``.weight``),
            ``propagate_embeddings_func`` and ``log_wandb``.
        optimizer: optimizer stepped once per batch.
        train_loader: loader whose ``dataset`` provides ``path_features``,
            ``links``, ``labels``, ``x``, ``degrees`` and, when embeddings
            are propagated, ``edge_index``.
        args: options read here: ``batch_size``, ``wandb``,
            ``propagate_embeddings`` and ``log_features``.
        device: device the batch tensors are moved to.
        emb: ignored. It is overwritten on every batch from the model's own
            node embeddings (or set to ``None``); kept for interface
            compatibility.

    Returns:
        The sum of per-example-weighted batch losses divided by
        ``len(train_loader.dataset)``.
    """
    print('starting training')
    t0 = time.time()
    model.train()
    total_loss = 0
    data = train_loader.dataset

    path_features = data.path_features

    # hydrate edges
    links = data.links
    labels = torch.tensor(data.labels)
    # Class balancing: positives keep weight 1 (1 + ratio, clamped), while
    # negatives get weight #pos / #neg, capped at 1.
    # neg_weight = 1.0 / (labels.shape[0] / labels.sum() - 1.0)
    neg_weight = labels.sum() / (labels.shape[0] - labels.sum())
    weights = (neg_weight + labels).clamp(max=1.0)

    # sampling (currently disabled; every link is used each epoch)
    # train_samples = get_num_samples(args.train_samples, len(labels))
    # sample_indices = torch.randperm(len(labels))[:train_samples]
    # links = links[sample_indices]
    # labels = labels[sample_indices]
    # path_features = path_features[sample_indices]
    # weights = weights[sample_indices]

    if args.wandb:
        wandb.log({"train_total_batches": len(train_loader)})

    # The graph does not change between batches, so move edge_index to the
    # device once instead of on every iteration.
    uses_embeddings = model.node_embedding is not None
    propagate = bool(uses_embeddings and args.propagate_embeddings)
    edge_index = data.edge_index.to(device) if propagate else None

    batch_processing_times = []
    loader = DataLoader(range(len(links)), args.batch_size, shuffle=True)
    for batch_count, indices in enumerate(tqdm(loader)):
        # do node level things
        # Embeddings are re-read every batch because the optimizer step
        # updates them.
        if not uses_embeddings:
            emb = None
        elif propagate:
            emb = model.propagate_embeddings_func(edge_index)
        else:
            emb = model.node_embedding.weight
        curr_links = links[indices]
        curr_path_features = path_features[indices].to(device)
        curr_weights = weights[indices].to(device)
        batch_emb = None if emb is None else emb[curr_links].to(device)

        node_features = data.x[curr_links].to(device)

        degrees = data.degrees[curr_links].to(device)
        # RA features are currently disabled.
        # if args.use_RA:
        #     ra_indices = sample_indices[indices]
        #     RA = data.RA[ra_indices].to(device)
        # else:
        #     RA = None
        RA = None
        start_time = time.time()
        optimizer.zero_grad()
        logits = model(curr_path_features, node_features, degrees[:, 0], degrees[:, 1], RA, batch_emb)
        # loss = get_loss(args.loss)(logits, labels[indices].squeeze(0).to(device))
        # reshape(-1) keeps a 1-D target even for a batch of one element;
        # squeeze(0) would collapse it to 0-dim and mismatch the logits.
        targets = labels[indices].reshape(-1).to(device).to(torch.float)
        loss = F.binary_cross_entropy_with_logits(logits.view(-1), targets, curr_weights)

        loss.backward()
        optimizer.step()
        # The loss is a batch mean; scale by the real batch length, since
        # the last batch may be shorter than args.batch_size.
        total_loss += loss.item() * len(indices)
        batch_processing_times.append(time.time() - start_time)

    if args.wandb:
        wandb.log({"train_batch_time": np.mean(batch_processing_times)})
        wandb.log({"train_epoch_time": time.time() - t0})

    print(f'training ran in {time.time() - t0}')

    if args.log_features:
        model.log_wandb()

    return total_loss / len(train_loader.dataset)